# coding=utf-8
# Copyright 2020 The TensorFlow GAN Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# python2 python3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import flags
import tensorflow.compat.v1 as tf

from tensorflow_gan.examples.progressive_gan import train

# absl flag container; `FLAGS.test_tmpdir` is read in `TrainTest.setUp` to
# build the training log directory path.
FLAGS = flags.FLAGS


def provide_random_data(batch_size=2, patch_size=4, colors=1, **unused_kwargs):
  """Returns a random-normal tensor standing in for a batch of real images.

  Args:
    batch_size: Size of the first dimension of the returned tensor.
    patch_size: Size of both the second and third dimensions.
    colors: Size of the last dimension.
    **unused_kwargs: Accepted and ignored.

  Returns:
    The result of `tf.random.normal` for the shape
    `[batch_size, patch_size, patch_size, colors]`.
  """
  del unused_kwargs  # Only accepted so callers may pass extra keywords.
  image_shape = [batch_size, patch_size, patch_size, colors]
  return tf.random.normal(image_shape)


class TrainTest(tf.test.TestCase):
  """Smoke tests for the progressive GAN `train` module."""

  def setUp(self):
    """Creates the small config dict shared by the tests.

    The dict is forwarded as keyword arguments (`**self._config`) to the
    `train` helpers. `train_log_dir` lives under `FLAGS.test_tmpdir`.
    """
    super(TrainTest, self).setUp()
    self._config = {
        'start_height': 2,
        'start_width': 2,
        'scale_base': 2,
        'num_resolutions': 2,
        'batch_size_schedule': [2],
        'colors': 1,
        'to_rgb_use_tanh_activation': True,
        'kernel_size': 3,
        'stable_stage_num_images': 1,
        'transition_stage_num_images': 1,
        'total_num_images': 3,
        'save_summaries_num_images': 2,
        'latent_vector_size': 2,
        'fmap_base': 8,
        'fmap_decay': 1.0,
        'fmap_max': 8,
        'gradient_penalty_target': 1.0,
        'gradient_penalty_weight': 10.0,
        'real_score_penalty_weight': 0.001,
        'generator_learning_rate': 0.001,
        'discriminator_learning_rate': 0.001,
        'adam_beta1': 0.0,
        'adam_beta2': 0.99,
        'fake_grid_size': 2,
        'interp_grid_size': 2,
        'train_log_dir': os.path.join(FLAGS.test_tmpdir, 'progressive_gan'),
        'master': '',
        'task': 0
    }

  def test_train_success(self):
    """Builds, summarizes and trains a model for every stage id.

    The test has no explicit assertions; it passes when none of the `train`
    calls raises. It is skipped under eager execution.
    """
    if tf.executing_eagerly():
      # Report a skip rather than returning silently, so the test is not
      # counted as passed without having exercised any training code.
      self.skipTest('`tfgan.gan_model` does not work when executing eagerly.')

    # `makedirs` succeeds when the directory already exists, so no separate
    # existence check is needed beforehand.
    tf.io.gfile.makedirs(self._config['train_log_dir'])

    for stage_id in train.get_stage_ids(**self._config):
      batch_size = train.get_batch_size(stage_id, **self._config)
      # Each stage builds its model in a fresh default graph.
      tf.reset_default_graph()
      real_images = provide_random_data(batch_size=batch_size)
      model = train.build_model(stage_id, batch_size, real_images,
                                **self._config)
      train.add_model_summaries(model, **self._config)
      train.train(model, **self._config)

  def test_get_batch_size(self):
    """Checks `train.get_batch_size` for stage ids 0 through 8."""
    config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
    # batch_size_schedule is expanded to [8, 8, 8, 4, 2]
    # At stage level it is [8, 8, 8, 8, 8, 4, 4, 2, 2]
    expected_batch_sizes = [8, 8, 8, 8, 8, 4, 4, 2, 2]
    for stage_id, expected_batch_size in enumerate(expected_batch_sizes):
      self.assertEqual(
          train.get_batch_size(stage_id, **config), expected_batch_size)


# Hand control to TensorFlow's test entry point when run as a script.
if __name__ == '__main__':
  tf.test.main()
